from typing import List, Union

from rich import print
from langchain_openai import OpenAIEmbeddings
from langchain_chroma import Chroma
from langchain.docstore.document import Document
from sentence_transformers import SentenceTransformer
from langchain_huggingface import HuggingFaceEmbeddings


class VectorReplayBuffer:
    """Experience-replay memory backed by a persistent Chroma vector store.

    Each entry stores a ``state`` string as the embedded document text, with
    ``obs``, ``reasoning``, ``action`` and ``comments`` kept as string
    metadata. Retrieval is by embedding similarity to a query state.
    """

    def __init__(self, data_path: str) -> None:
        """Open (or create) the Chroma store persisted under ``data_path``.

        Args:
            data_path: Directory passed to Chroma as ``persist_directory``.
        """
        self.embedding = OpenAIEmbeddings()
        # Alternative local embedding backend:
        # self.embedding = HuggingFaceEmbeddings(model="sentence-transformers/all-mpnet-base-v2")
        self.data_path = data_path
        self.memory = Chroma(
            embedding_function=self.embedding,
            persist_directory=self.data_path
        )
        print("=== VectorExperienceReplay === ", data_path, ", Now the database has ",
              self._count(), " items.==========")

    def _count(self) -> int:
        """Return the number of entries currently stored in the collection."""
        # count() avoids pulling every embedding just to measure the length.
        return self.memory._collection.count()

    def add(self, state, obs, reasoning, action, comments="") -> Union[int, List[str]]:
        """Store a new experience keyed by ``state`` unless it already exists.

        Args:
            state: Text that is embedded and stored as the document content.
            obs: Observation, stored as ``str(obs)`` metadata.
            reasoning: Stored as ``str(reasoning)`` metadata.
            action: Stored as ``str(action)`` metadata.
            comments: Stored as ``str(comments)`` metadata.

        Returns:
            ``-1`` if an entry with exactly the same ``state`` text is already
            stored; otherwise the list of ids returned by ``add_documents``.
        """
        # "$contains" is a substring filter, so it only narrows the candidates;
        # an exact comparison is required so that e.g. "ab" is not treated as
        # a duplicate of an existing "abc".
        candidates = self.memory._collection.get(
            where_document={"$contains": state}, include=["documents"]
        )
        if any(document == state for document in candidates["documents"] or []):
            print("State already exists in the memory. Skipping...")
            return -1

        doc = Document(
            page_content=state,
            metadata={'obs': str(obs),
                      'reasoning': str(reasoning),
                      'action': str(action),
                      'comments': str(comments),
                      },
        )
        ids = self.memory.add_documents([doc])
        print("State added to the memory with id: ", ids, ". Now the database has ",
              self._count(), " items.")
        return ids

    def retrive(self, state, k: int = 1) -> List[dict]:
        """Return the metadata dicts of the ``k`` stored entries most similar to ``state``."""
        similar_results = self.memory.similarity_search_with_score(state, k=k)
        # Results are (Document, score) pairs; only the metadata is returned.
        return [document.metadata for document, _score in similar_results]

    # Correctly spelled alias; ``retrive`` is kept for existing callers.
    retrieve = retrive

    def delete(self, ids: List[str]) -> bool:
        """Delete entries by their Chroma ids.

        Args:
            ids: Ids as returned by :meth:`add`.

        Returns:
            True on success (an empty ``ids`` list is a no-op and counts as
            success); False if Chroma raised an error during deletion.
        """
        if not ids:
            # Never hand Chroma an empty id list; there is nothing to delete.
            return True
        try:
            self.memory._collection.delete(ids=ids)
        except Exception as err:
            # Best-effort contract (bool result) is kept, but the cause is reported.
            print("Delete failed:", err)
            return False
        print("Delete", len(ids), "memory items. Now the database has ",
              self._count(), " items.")
        return True

    def combine(self, other_memory) -> int:
        """Copy entries from another store into this one.

        Args:
            other_memory: Another ``VectorReplayBuffer``, or a LangChain
                ``Chroma`` store (anything exposing ``_collection``).

        Returns:
            The number of entries copied. An entry is skipped when its id or
            its document text is already present here. Stored embeddings are
            copied as-is, so nothing is re-embedded.
        """
        source = other_memory.memory if isinstance(other_memory, VectorReplayBuffer) else other_memory
        other_items = source._collection.get(include=["documents", "metadatas", "embeddings"])
        current_items = self.memory._collection.get(include=["documents"])

        known_ids = set(current_items["ids"])
        known_docs = set(current_items["documents"] or [])

        new_ids, new_embeddings, new_metadatas, new_documents = [], [], [], []
        for item_id, document, metadata, embedding in zip(
            other_items["ids"],
            other_items["documents"],
            other_items["metadatas"],
            other_items["embeddings"],
        ):
            if item_id in known_ids or document in known_docs:
                print("State already exists in the memory. Skipping...")
                continue
            # Also dedupes repeated entries inside ``other_memory`` itself.
            known_ids.add(item_id)
            known_docs.add(document)
            new_ids.append(item_id)
            # Chroma may hand back numpy rows; convert to plain Python floats.
            new_embeddings.append(
                embedding.tolist() if hasattr(embedding, "tolist") else list(embedding)
            )
            new_metadatas.append(metadata)
            new_documents.append(document)

        if new_ids:
            # One batched insert; Chroma rejects an add() with empty lists.
            self.memory._collection.add(
                ids=new_ids,
                embeddings=new_embeddings,
                metadatas=new_metadatas,
                documents=new_documents,
            )
        print("Merge complete. Now the database has ", self._count(), " items.")
        return len(new_ids)
        

if __name__ == "__main__":
    # Manual smoke test: open (or create) the on-disk test database, then drop
    # into an IPython shell with `data_path` and `vector_memory` in scope.
    data_path = "memory_database/test_database"
    vector_memory = VectorReplayBuffer(data_path)

    # Imported lazily: IPython is only needed for this interactive entry point.
    from IPython import embed
    embed()
